
from kfp import compiler
import kfp.dsl as dsl
from kubernetes import client as k8s_client


class load_dataOp(dsl.ContainerOp):
    """Pipeline step: fetch the raw MNIST data and run the transform stage.

    The container writes the location of the produced train/test data to
    ``<data_dir>train_test_data.txt``; KFP surfaces that file's contents
    as the step's ``data_file`` output for downstream steps.
    """

    def __init__(self, data_dir, file_name, if_gray):
        # Forward every argument verbatim as a CLI flag to the step script.
        step_cmd = [
            'python3', 'load_data.py',
            '--data_dir', data_dir,
            '--file_name', file_name,
            '--if_gray', if_gray,
        ]
        super().__init__(
            name='load_data',
            image='mnist-load_data:v0.0.2',
            command=step_cmd,
            file_outputs={'data_file': data_dir + 'train_test_data.txt'},
        )


class trainOp(dsl.ContainerOp):
    """Pipeline step: fit the Keras MNIST model on the prepared data.

    The container records the saved model location in
    ``<data_dir>model.txt``, exposed as the ``model_path`` output.
    """

    def __init__(self, data_dir, data_file, batch_size, lr, epoches):
        # Build the training command incrementally; hyper-parameters are
        # passed straight through as CLI flags.
        cmd = ['python3', 'train_mnist.py']
        cmd += ['--data_dir', data_dir]
        cmd += ['--data_file', data_file]
        cmd += ['--batch_size', batch_size]
        cmd += ['--lr', lr]
        cmd += ['--epoches', epoches]
        super().__init__(
            name='train',
            image='mnist-train:v0.0.2',
            command=cmd,
            file_outputs={'model_path': data_dir + 'model.txt'},
        )


class predictOp(dsl.ContainerOp):
    """Pipeline step: run inference with the trained model.

    The container writes its prediction results path to
    ``<data_dir>result.txt``, exposed as the ``result_file`` output.
    """

    def __init__(self, data_file, model_file, data_dir):
        # Flag -> value mapping; dict order is preserved (Py3.7+), so the
        # resulting command line matches the documented flag order.
        flags = {
            '--data_file': data_file,
            '--model_path': model_file,
            '--data_dir': data_dir,
        }
        cmd = ['python3', 'predict_mnist.py']
        for flag, value in flags.items():
            cmd.extend((flag, value))
        super().__init__(
            name='predict',
            image='mnist-predict:v0.0.2',
            command=cmd,
            file_outputs={'result_file': data_dir + 'result.txt'},
        )


@dsl.pipeline(
    name='MnistStage',
    description='Three-step MNIST pipeline: load data, train a Keras model, predict.'
)
def MnistTest(
        file_name,
        if_gray,
        batch_size,
        lr,
        epoches,
):
    """Wire up the load_data -> train -> predict MNIST pipeline.

    Every step mounts the same NFS share at ``/nfs/aiflow/`` so steps can
    exchange intermediate files through the shared filesystem.

    Args:
        file_name: dataset file name passed to the load_data step.
        if_gray: grayscale-conversion flag forwarded to load_data.
        batch_size: training batch size.
        lr: learning rate.
        epoches: number of training epochs.
    """
    data_dir = '/nfs/aiflow/'

    def _mount_nfs(op):
        # All three steps attach the identical NFS-backed volume; doing it
        # in one helper replaces three copy-pasted add_volume/add_volume_mount
        # chains and keeps path/server in a single place.
        return op.add_volume(
            k8s_client.V1Volume(
                name='aiflow',
                nfs=k8s_client.V1NFSVolumeSource(path=data_dir,
                                                 server='master'))
        ).add_volume_mount(
            k8s_client.V1VolumeMount(mount_path=data_dir, name='aiflow'))

    load_data = _mount_nfs(load_dataOp(data_dir, file_name, if_gray))
    train = _mount_nfs(
        trainOp(data_dir, load_data.outputs['data_file'], batch_size, lr,
                epoches))
    # The predict op is terminal; its handle is not needed afterwards.
    _mount_nfs(
        predictOp(load_data.outputs['data_file'],
                  train.outputs['model_path'], data_dir))

    # Step images live in a private registry; pull with the 'aiflow' secret.
    dsl.get_pipeline_conf().set_image_pull_secrets(
        [k8s_client.V1ObjectReference(name="aiflow")])


if __name__ == '__main__':
    # Compile the pipeline DSL into a workflow spec on disk.
    compiler.Compiler().compile(MnistTest, 'pipeline.yaml')